package org.hepeng.workx.mybatis.executor;


import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.alibaba.druid.sql.visitor.SchemaStatVisitor;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.builder.annotation.ProviderSqlSource;
import org.apache.ibatis.exceptions.ExceptionFactory;
import org.apache.ibatis.executor.ErrorContext;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.scripting.xmltags.DynamicSqlSource;
import org.hepeng.workx.mybatis.event.DeleteEvent;
import org.hepeng.workx.mybatis.event.ExecuteErrorEvent;
import org.hepeng.workx.mybatis.event.ExecuteEvent;
import org.hepeng.workx.mybatis.event.publisher.ExecuteEventPublisher;
import org.hepeng.workx.mybatis.event.FlushCacheEvent;
import org.hepeng.workx.mybatis.event.FlushEvent;
import org.hepeng.workx.mybatis.event.InsertEvent;
import org.hepeng.workx.mybatis.event.QueryCursorEvent;
import org.hepeng.workx.mybatis.event.QueryEvent;
import org.hepeng.workx.mybatis.event.TxCommitEvent;
import org.hepeng.workx.mybatis.event.TxRollbackEvent;
import org.hepeng.workx.mybatis.event.UpdateEvent;
import org.hepeng.workx.mybatis.sql.SQLStatementInfo;
import org.hepeng.workx.util.proxy.Invocation;
import org.hepeng.workx.util.proxy.Invoker;
import org.joor.Reflect;
import org.springframework.util.Assert;

import java.lang.reflect.Method;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
 * {@link Invoker} for a MyBatis executor proxy that publishes {@link ExecuteEvent}s for each
 * intercepted call.
 *
 * <p>Before delegating, the call's arguments and, when a {@link MappedStatement} is among them,
 * its SQL text, command type, flush-cache flag and (for {@code update}/{@code query}/
 * {@code queryCursor}) parsed statement info are stored in {@link SQLExecuteContext}. After a
 * successful call, the result is stored too. Events are then built from that context.
 *
 * @author he peng
 */
public class EventPublishExecutorProxy implements Invoker {

    /** Method names whose {@link SQLExecuteContext} is closed once the call and its events finish. */
    private static final String[] CONTEXT_CLOSING_METHODS =
            {"commit", "rollback", "update", "query", "queryCursor"};

    /** Method names that execute a mapped statement and therefore get their SQL parsed. */
    private static final String[] STATEMENT_METHODS = {"update", "query", "queryCursor"};

    /**
     * Parsed-statement cache for non-dynamic SQL sources, keyed by SQL text. Dynamic and provider
     * SQL is never cached because its text can differ per parameter, which would grow the cache
     * without bound.
     */
    private static final Map<String , SQLStatementInfo> SQL_STATEMENT_INFO_CACHE = new ConcurrentHashMap<>();

    private final ExecuteEventPublisher publisher;

    /**
     * @param publisher destination of every event this proxy publishes; must not be null
     */
    public EventPublishExecutorProxy(ExecuteEventPublisher publisher) {
        Assert.notNull(publisher , "publisher must not be null");
        this.publisher = publisher;
    }

    /**
     * Delegates the call and publishes the matching events.
     *
     * <p>Methods declared on {@link Object} are passed straight through. For all other methods:
     * <ul>
     *   <li>the context is populated;</li>
     *   <li>the call is delegated;</li>
     *   <li>an {@link ExecuteErrorEvent} is published on failure;</li>
     *   <li>the method's regular events are published from a {@code finally} block, so they fire
     *       even when the call failed.</li>
     * </ul>
     * The context is cleared afterwards, even if context population or event publishing throws.
     *
     * @throws Throwable exceptions from the delegate, wrapped via {@link ExceptionFactory};
     *         {@link Error}s are rethrown unchanged
     */
    @Override
    public Object invoke(Invocation invocation) throws Throwable {
        Method method = invocation.getMethod();
        if (Object.class.equals(method.getDeclaringClass())) {
            try {
                return invocation.invoke();
            } catch (Throwable t) {
                ErrorContext.instance().cause(t);
                throw translate(t);
            }
        }

        String methodName = method.getName();
        try {
            preInvoke(invocation);
            Object result;
            try {
                result = invocation.invoke();
                postInvoke(invocation.getProxy() , invocation.getArgs() , result);
            } catch (Throwable t) {
                publishErrorEvent(t);
                throw translate(t);
            } finally {
                publishEvent(methodName);
            }
            return result;
        } finally {
            // Outer finally so the context is released even when preInvoke or event
            // publishing throws.
            clearContext(methodName);
        }
    }

    /**
     * Converts a failure into what executor callers expect. Exceptions are wrapped through
     * MyBatis' {@link ExceptionFactory}. Non-{@link Exception} throwables (i.e. {@link Error}s)
     * are returned unchanged instead of failing with a {@link ClassCastException}.
     */
    private static Throwable translate(Throwable t) {
        if (t instanceof Exception) {
            return ExceptionFactory.wrapException("" , (Exception) t);
        }
        return t;
    }

    /** Closes the {@link SQLExecuteContext} for the methods that own a full execute cycle. */
    private void clearContext(String methodName) {
        if (StringUtils.equalsAny(methodName , CONTEXT_CLOSING_METHODS)) {
            SQLExecuteContext.close();
        }
    }

    /**
     * Records the delegate's return value in the context under {@code "result"}.
     *
     * @param obj    the proxy instance
     * @param args   the call arguments
     * @param result value returned by the delegate
     */
    protected void postInvoke(Object obj, Object[] args, Object result) {
        SQLExecuteContext.getContext().set("result" , result);
    }

    /**
     * Populates the context before delegation.
     *
     * <p>{@code "args"} is always set. When an argument is a {@link MappedStatement}, the first
     * one found also sets {@code "sql"}, {@code "sqlCommandType"} and
     * {@code "flushCacheRequired"}. For {@code update}/{@code query}/{@code queryCursor} it also
     * sets {@code "sqlStatementInfo"}.
     */
    protected void preInvoke(Invocation invocation) {
        SQLExecuteContext context = SQLExecuteContext.getContext();
        Object[] args = invocation.getArgs();
        context.set("args" , args);
        if (ArrayUtils.isEmpty(args)) {
            return;
        }

        MappedStatement ms = findMappedStatement(args);
        if (ms == null) {
            return;
        }

        // For Executor.update/query/queryCursor the parameter object follows the
        // MappedStatement at index 1; guard the length instead of assuming it.
        Object parameter = args.length > 1 ? args[1] : null;
        // Build the SQL once and reuse it for both the context and parsing.
        String sql = ms.getBoundSql(parameter).getSql();
        context.set("sql" , sql);
        context.set("sqlCommandType" , ms.getSqlCommandType());
        context.set("flushCacheRequired" , ms.isFlushCacheRequired());
        if (StringUtils.equalsAny(invocation.getMethod().getName() , STATEMENT_METHODS)) {
            context.set("sqlStatementInfo" , parseSqlStatement(ms , sql));
        }
    }

    /** Returns the first {@link MappedStatement} among the arguments, or {@code null}; null-safe. */
    private static MappedStatement findMappedStatement(Object[] args) {
        for (Object arg : args) {
            if (arg instanceof MappedStatement) {
                return (MappedStatement) arg;
            }
        }
        return null;
    }

    /**
     * Returns parsed statement info for {@code sql}.
     *
     * <p>Dynamic and provider SQL sources are parsed on every call. Other sources are cached by
     * SQL text. The parse runs outside any lock, and {@code putIfAbsent} keeps the first stored
     * instance when threads race.
     */
    private SQLStatementInfo parseSqlStatement(MappedStatement ms , String sql) {
        Class<?> sqlSourceType = ms.getSqlSource().getClass();
        boolean variableSql = DynamicSqlSource.class.isAssignableFrom(sqlSourceType)
                || ProviderSqlSource.class.isAssignableFrom(sqlSourceType);
        if (variableSql) {
            return doSqlParse(ms , sql);
        }

        SQLStatementInfo cached = SQL_STATEMENT_INFO_CACHE.get(sql);
        if (cached != null) {
            return cached;
        }
        SQLStatementInfo parsed = doSqlParse(ms , sql);
        SQLStatementInfo raced = SQL_STATEMENT_INFO_CACHE.putIfAbsent(sql , parsed);
        return raced != null ? raced : parsed;
    }

    /** Parses {@code sql} with Druid and collects its tables and columns. */
    private SQLStatementInfo doSqlParse(MappedStatement ms , String sql) {
        SQLStatement sqlStatement = new SQLStatementParser(sql).parseStatement(true);
        SchemaStatVisitor visitor = newSchemaStatVisitor(ms);
        sqlStatement.accept(visitor);

        return SQLStatementInfo.builder()
                .sql(sql)
                .sqlCommandType(ms.getSqlCommandType())
                .tables(visitor.getTables())
                .columns(visitor.getColumns())
                .build();
    }

    /**
     * Returns a fresh visitor of the same runtime type as the one held in the configuration's
     * {@code schemaStatVisitor} field.
     *
     * <p>A SchemaStatVisitor accumulates what it visits, so reusing the single configured
     * instance would mix tables/columns from earlier statements into later ones and race across
     * threads. If the configured type has no accessible no-arg constructor, the configured
     * instance is used as before.
     *
     * <p>NOTE(review): a fresh instance does not inherit state from the configured one
     * (e.g. a pre-populated schema repository). Confirm that is acceptable.
     */
    private static SchemaStatVisitor newSchemaStatVisitor(MappedStatement ms) {
        SchemaStatVisitor configured = Reflect.on(ms.getConfiguration()).get("schemaStatVisitor");
        Assert.state(configured != null , "configuration field 'schemaStatVisitor' must not be null");
        try {
            return configured.getClass().getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            return configured;
        }
    }

    /** Publishes the events matching {@code methodName}; names with no matching event publish nothing. */
    private void publishEvent(String methodName) {
        publishUpdateEvent(methodName);
        publishQueryEvent(methodName);
        publishCommitEvent(methodName);
        publishRollbackEvent(methodName);
    }

    /** Publishes a {@link TxRollbackEvent} for {@code rollback}. */
    private void publishRollbackEvent(String methodName) {
        if (! StringUtils.equals(methodName , "rollback")) {
            return;
        }
        ExecuteEvent event = new TxRollbackEvent(
                System.currentTimeMillis() , null , SQLExecuteContext.getContext());
        this.publisher.publishEvent(event);
    }

    /** Publishes a {@link TxCommitEvent} for {@code commit}. */
    private void publishCommitEvent(String methodName) {
        if (! StringUtils.equals(methodName ,"commit")) {
            return;
        }
        ExecuteEvent event = new TxCommitEvent(
                System.currentTimeMillis() , null , SQLExecuteContext.getContext());
        this.publisher.publishEvent(event);
    }

    /**
     * Publishes query events for {@code query} and {@code queryCursor} calls whose command type
     * is SELECT.
     *
     * <p>A {@link QueryEvent} is published for both. {@code queryCursor} additionally publishes a
     * {@link QueryCursorEvent}. Previously the guard admitted only {@code "query"}, so the cursor
     * branch was unreachable.
     */
    private void publishQueryEvent(String methodName) {
        boolean isQuery = StringUtils.equals(methodName , "query");
        boolean isQueryCursor = StringUtils.equals(methodName , "queryCursor");
        if (! isQuery && ! isQueryCursor) {
            return;
        }

        SQLExecuteContext context = SQLExecuteContext.getContext();
        if (! Objects.equals(SqlCommandType.SELECT, context.getSqlCommandType())) {
            return;
        }

        this.publisher.publishEvent(new QueryEvent(System.currentTimeMillis() , null , context));
        if (isQueryCursor) {
            this.publisher.publishEvent(new QueryCursorEvent(System.currentTimeMillis() , null , context));
        }
    }

    /** Publishes an {@link ExecuteErrorEvent} carrying the current context and the failure. */
    private void publishErrorEvent(Throwable t) {
        this.publisher.publishEvent(new ExecuteErrorEvent(
                        System.currentTimeMillis() ,
                        null ,
                        SQLExecuteContext.getContext() , t));
    }

    /**
     * For {@code update}, publishes the event matching the statement's command type
     * (INSERT/UPDATE/DELETE/FLUSH).
     *
     * <p>When the statement requires a cache flush, a {@link FlushCacheEvent} is published first.
     * Command types without a dedicated event publish no command event; previously a
     * {@code null} event was passed to the publisher.
     */
    private void publishUpdateEvent(String methodName) {
        if (! StringUtils.equals(methodName , "update")) {
            return;
        }

        SQLExecuteContext context = SQLExecuteContext.getContext();
        SqlCommandType sqlCommandType = context.getSqlCommandType();
        if (Objects.isNull(sqlCommandType)) {
            return;
        }

        ExecuteEvent event;
        switch (sqlCommandType) {
            case INSERT:
                event = new InsertEvent(System.currentTimeMillis() , null , context);
                break;
            case UPDATE:
                event = new UpdateEvent(System.currentTimeMillis() , null , context);
                break;
            case DELETE:
                event = new DeleteEvent(System.currentTimeMillis() , null , context);
                break;
            case FLUSH:
                event = new FlushEvent(System.currentTimeMillis() , null , context);
                break;
            default:
                event = null;
                break;
        }

        // Null-safe: an unset flag must not fail with an unboxing NPE.
        if (Boolean.TRUE.equals(context.getFlushCacheRequired())) {
            ExecuteEvent flushCacheEvent = new FlushCacheEvent(System.currentTimeMillis() , null , context);
            this.publisher.publishEvent(flushCacheEvent);
        }
        if (event != null) {
            this.publisher.publishEvent(event);
        }
    }
}
